{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# {meth}`~tvm.arith.analyzer.Analyzer.simplify`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/media/pc/data/lxw/ai/tvm-book/doc/read/arith\n"
     ]
    }
   ],
   "source": [
    "%cd ..\n",
    "import testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm\n",
    "import tvm.testing\n",
    "from tvm import tir\n",
    "from tvm.script import tir as T"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## {meth}`~tvm.arith.analyzer.Analyzer.simplify` reshape_flattened"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split a flattened index into its `// 12`, `% 12 // 4` and `% 4` parts and check\n",
    "# that the analyzer folds their recombination back into the flattened index.\n",
    "ana = tvm.arith.Analyzer()\n",
    "\n",
    "i0, i1 = (tir.Var(name, \"int64\") for name in (\"i0\", \"i1\"))\n",
    "for var, extent in ((i0, 8), (i1, 3)):\n",
    "    ana.bind(var, tvm.ir.Range(0, extent))\n",
    "\n",
    "i_flattened = i0 * 3 + i1\n",
    "recombined = (i_flattened // 12) * 12 + (i_flattened % 12 // 4) * 4 + i_flattened % 4\n",
    "assert tvm.ir.structural_equal(ana.simplify(recombined), i_flattened)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## {meth}`~tvm.arith.analyzer.Analyzer.simplify` symbolic_comparison"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "ana = tvm.arith.Analyzer()\n",
    "\n",
    "i0 = tir.Var(\"i0\", \"int64\")\n",
    "i1 = tir.Var(\"i1\", \"int64\")\n",
    "n = tvm.tir.SizeVar(\"n\", \"int64\")\n",
    "m = tvm.tir.SizeVar(\"m\", \"int64\")\n",
    "outer = (n + 31) // 32\n",
    "ana.bind(i0, tvm.ir.Range(0, outer))\n",
    "ana.bind(i1, tvm.ir.Range(0, 32))\n",
    "PS = tvm.arith.ProofStrength\n",
    "\n",
    "# Name the two sides that every comparison below is built from.\n",
    "index = i0 * 32 + i1\n",
    "padded = (n + 31) // 32 * 32\n",
    "\n",
    "# The default strength is expected to fail here; the symbolic-bound strength succeeds.\n",
    "assert not ana.can_prove(index < padded, PS.DEFAULT)\n",
    "for claim in (\n",
    "    index < padded,\n",
    "    index < padded + m,\n",
    "    index + 1 <= padded,\n",
    "    padded >= index + 1,\n",
    "    padded >= index,\n",
    "):\n",
    "    assert ana.can_prove(claim, PS.SYMBOLIC_BOUND)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## {meth}`~tvm.arith.analyzer.Analyzer.simplify` vscale_comparison_with_sve_target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "SVE_TARGET = \"llvm -mtriple=aarch64-linux-gnu -mattr=+sve\"\n",
    "\n",
    "# Comparisons over vscale that the analyzer is expected to prove under an SVE target.\n",
    "provable_with_sve = (\n",
    "    T.vscale() * 32 < T.vscale() * 64,\n",
    "    T.vscale() * 2 * (T.vscale() * 2) >= T.vscale() * 4,\n",
    "    (T.vscale() * 4 + 114) // (T.vscale() * 4) * (T.vscale() * 4) >= 115,\n",
    "    64 % T.vscale() <= T.vscale(),\n",
    ")\n",
    "\n",
    "for expression in provable_with_sve:\n",
    "    # A fresh analyzer per expression keeps the checks independent of each other.\n",
    "    ana = tvm.arith.Analyzer()\n",
    "    with tvm.target.Target(SVE_TARGET):\n",
    "        assert ana.can_prove(expression)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## {meth}`~tvm.arith.analyzer.Analyzer.simplify` vscale_comparison_without_sve_target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[09:35:26] /media/pc/data/lxw/ai/tvm/src/arith/analyzer.cc:240: Warning: The expression contains scalable values. An attempt to prove by substituting with known values of vscale was not performed. This proof currently only supports AArch64 SVE targets, but the target was llvm -keys=arm_cpu,cpu -mtriple=aarch64-linux-gnu\n"
     ]
    }
   ],
   "source": [
    "ana = tvm.arith.Analyzer()\n",
    "vs = tvm.tir.vscale()\n",
    "\n",
    "# Without `+sve` in the target the analyzer does not substitute known vscale values\n",
    "# (it only emits a warning on stderr), so the comparison is expected to stay unproven.\n",
    "with tvm.target.Target(\"llvm -mtriple=aarch64-linux-gnu\"):\n",
    "    assert not ana.can_prove(vs * 32 < vs * 64)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## {meth}`~tvm.arith.analyzer.Analyzer.simplify` vscale_non_comparison"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "ana = tvm.arith.Analyzer()\n",
    "vs = tvm.tir.vscale()\n",
    "\n",
    "# `can_prove` rejects a scalable expression that is not a comparison. The error is the\n",
    "# point of this cell, so catch it and check its message instead of letting the cell\n",
    "# abort a Restart & Run All.\n",
    "try:\n",
    "    with tvm.target.Target(\"llvm -mtriple=aarch64-linux-gnu -mattr=+sve\"):\n",
    "        ana.can_prove(vs * 4)\n",
    "except tvm.TVMError as err:\n",
    "    assert \"Expected comparison but got: T.vscale() * 4\" in str(err)\n",
    "else:\n",
    "    raise AssertionError(\"can_prove accepted a non-comparison vscale expression\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## {meth}`~tvm.arith.analyzer.Analyzer.simplify` regression_simplify_inf_recursion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "T.Cast(\"int32\", T.Cast(\"int8\", cond != 0) - T.Cast(\"int8\", cond != 0)) == 0"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ana = tvm.arith.Analyzer()\n",
    "cond = tir.Var(\"cond\", \"int32\")\n",
    "\n",
    "# Regression check for an earlier failure: the comparison rewrite and the integer-set\n",
    "# analysis could call each other recursively and loop forever on this expression.\n",
    "lhs = tvm.tir.NE(cond, 0).astype(\"int8\")\n",
    "rhs = tvm.tir.NE(cond, 0).astype(\"int8\")\n",
    "res = (lhs - rhs).astype(\"int32\") == 0\n",
    "ana.rewrite_simplify(res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
